Image to Captions --- Show and Tell Model with Cats and Dogs

Show and Tell Model --- Given an image, the model generates captions.

Generated captions:

  • a cat laying on top of a rug next to a cat
  • a cat laying on the floor next to a cat
  • a cat laying on top of a rug next to a cat

The "Show and Tell" model presented in this notebook is based on the work in https://github.com/tensorflow/models/tree/master/im2txt, with modifications that make training much faster (from about one week on a GPU to a few hours on CPU only). Specifically:

  1. Only cat and dog images are used, which is about 8% of the full dataset.
  2. Inception embeddings are pre-generated once instead of being computed during training, which removes the cost of embedding images on every training step. The quality is equivalent to the first training phase of the GitHub example above.

The following diagram illustrates the model architecture (for details, see the Show and Tell model on GitHub).

Send any feedback to datalab-feedback@google.com.

Requirement

About 150 GB of disk is required. An n1-standard-1 VM is probably not enough; a high-memory machine type is recommended. If you use the "datalab create" command to create the Datalab instance, add the "--machine-type n1-highmem-2" option. See https://cloud.google.com/datalab/docs/how-to/machine-type for instructions.
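For example, a suitable instance could be created with a command along the following lines (the instance name is a placeholder and the exact flags are assumptions; check the documentation linked above for the full set of options):

datalab create --machine-type n1-highmem-2 --disk-size-gb 200 img2txt-datalab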

Download Data

We will use the MSCOCO dataset. Although we only use the cat- and dog-related images and captions, we need to download the zip packages containing the full data.


In [ ]:
# Download Images Data

!mkdir -p /content/datalab/img2txt/images
!wget -P /content/datalab/img2txt/ http://msvocds.blob.core.windows.net/coco2014/train2014.zip
!wget -P /content/datalab/img2txt/ http://msvocds.blob.core.windows.net/coco2014/val2014.zip
!unzip -q -j /content/datalab/img2txt/train2014.zip -d /content/datalab/img2txt/images
!unzip -q -j /content/datalab/img2txt/val2014.zip -d /content/datalab/img2txt/images

In [3]:
# Download Captions Data

!wget -P /content/datalab/img2txt/ http://msvocds.blob.core.windows.net/annotations-1-0-3/captions_train-val2014.zip
!unzip -q -j /content/datalab/img2txt/captions_train-val2014.zip -d /content/datalab/img2txt/

Common Code


In [1]:
from datetime import datetime
from random import randint
import os
import shutil
import six
import tempfile
import tensorflow as tf
from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3
from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_arg_scope
import yaml

In [2]:
def save_vocab(word_to_id, vocab_file):
    """Save vocabulary to file."""

    with tf.gfile.Open(vocab_file, 'w') as fw:
        yaml.dump(word_to_id, fw, default_flow_style=False)


def load_vocab(vocab_file):
    """Load vocabulary from file."""
    
    with tf.gfile.Open(vocab_file, 'r') as fr:
        return yaml.load(fr)  


def get_instances_size(file_pattern):
    """Count training instances from tf.example file."""

    c = sum(1 for x in tf.python_io.tf_record_iterator(file_pattern))
    print('instances size is %d' % c)
    return c

In [3]:
INCEPTION_V3_CHECKPOINT = 'gs://cloud-ml-data/img/flower_photos/inception_v3_2016_08_28.ckpt'
INCEPTION_EXCLUDED_VARIABLES = ['InceptionV3/AuxLogits', 'InceptionV3/Logits', 'global_step']


def make_batches(iterable, n):
    """Make batches with iterable."""

    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]


def build_image_processing(image_str_tensor):
    """Create image-to-embeddings tf graph."""

    def _decode_and_resize(image_str_tensor):
        """Decodes jpeg string, resizes it and returns a uint8 tensor."""

        # These constants are set by Inception v3's expectations.
        height = 299
        width = 299
        channels = 3

        image = tf.image.decode_jpeg(image_str_tensor, channels=channels)
        image = tf.expand_dims(image, 0)
        image = tf.image.resize_bilinear(image, [height, width], align_corners=False)
        image = tf.squeeze(image, squeeze_dims=[0])
        image = tf.cast(image, dtype=tf.uint8)
        return image

    image = tf.map_fn(_decode_and_resize, image_str_tensor, back_prop=False, dtype=tf.uint8)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = tf.subtract(image, 0.5)
    inception_input = tf.multiply(image, 2.0)

    # Build Inception layers, which expect a tensor of type float from [-1, 1)
    # and shape [batch_size, height, width, channels].
    with tf.contrib.slim.arg_scope(inception_v3_arg_scope()):
        _, end_points = inception_v3(inception_input, is_training=False)    
    embeddings = end_points['PreLogits']
    inception_embeddings = tf.squeeze(embeddings, [1, 2], name='SpatialSqueeze')
    return inception_embeddings


def load_inception_checkpoint(sess, vars_to_restore, checkpoint_path=None):
    """Load the Inception checkpoint into a session."""

    saver = tf.train.Saver(vars_to_restore)
    if checkpoint_path is None:
        checkpoint_dir = tempfile.mkdtemp()
        try:
            checkpoint_tmp = os.path.join(checkpoint_dir, 'checkpoint')    
            with tf.gfile.Open(INCEPTION_V3_CHECKPOINT, 'r') as f_in, tf.gfile.Open(checkpoint_tmp, 'w') as f_out:
                f_out.write(f_in.read())

            saver.restore(sess, checkpoint_tmp)
        finally:
            shutil.rmtree(checkpoint_dir)
    else:
        saver.restore(sess, checkpoint_path)

Data Preprocessing


In [4]:
# Extract vocabs, images files, captions that are only related to cats and dogs.

from collections import Counter
import six

# If empty, all data is included. Otherwise, include only images with any of the words in its captions.
KEYWORDS = {'cat', 'cats', 'kitten', 'kittens', 'dog', 'dogs', 'puppy', 'puppies'}

# Sentence start, sentence end, and unknown word.
CONTROL_WORDS = ['<s>', '</s>', '<unk>']


def extract(train_content, val_content):
    """Extract vocab, captions, and image files from raw data.
    
    Returns:
      A tuple of the following
        - Vocab: in the form of word_to_id dict.
        - id_wids: A dictionary with key an id, and value a list of captions, where each caption is
                   represented by a list of word ids.
        - id_imagefiles: A dictionary with key an id, and value a path of image file.
    """

    id_captions = [(x['image_id'], x['caption']) for x in train_content['annotations']]
    id_captions += [(x['image_id'], x['caption']) for x in val_content['annotations']]
    id_captions = [(k, v.replace('.', '').replace(',', '').lower().split()) for k, v in id_captions]

    # key - id, value - a list of captions
    id_captions_filtered = {}
    for x in id_captions:
        if not KEYWORDS or (KEYWORDS & set(x[1])):
            id_captions_filtered.setdefault(x[0], []).append(x[1])

    print('number of captions is %d' % sum(len(x) for x in id_captions_filtered.values()))

    words = [w for captions in id_captions_filtered.values() for caption in captions for w in caption]
    counts = Counter(words)
    counts = [x for x in counts.items() if x[1] > 5]
    counts = sorted(counts, key=lambda x: (x[1]), reverse=True)
    counts += [(x, 0) for x in CONTROL_WORDS]

    word_to_id = {str(word_cnt_pair[0]): idx for idx, word_cnt_pair in enumerate(counts)}
    print('vocab size is %d' % len(word_to_id))
    
    id_wids = {}
    for k, v in six.iteritems(id_captions_filtered):
        sentences = []
        for caption in v:
            wids = [word_to_id[x] if x in word_to_id else word_to_id['<unk>'] for x in caption]
            wids = [word_to_id['<s>']] + wids + [word_to_id['</s>']]
            sentences.append(wids)
        id_wids[k] = sentences

    id_imagefiles = {x['id']: x['file_name'] for x in train_content['images']}
    id_imagefiles.update({x['id']: x['file_name'] for x in val_content['images']})    
    id_imagefiles_filtered = {k: v for k, v in six.iteritems(id_imagefiles) if k in id_wids}
    print('number of images is %d' % len(id_imagefiles_filtered))

    return word_to_id, id_wids, id_imagefiles_filtered

In [5]:
# Load data from files.

import json

with open('/content/datalab/img2txt/captions_val2014.json', 'r') as f:
    val_content = json.load(f)

with open('/content/datalab/img2txt/captions_train2014.json', 'r') as f:
    train_content = json.load(f)

word_to_id, id_wids, id_imagefiles = extract(train_content, val_content)


number of captions is 40718
vocab size is 1936
number of images is 9497

In [6]:
# Save the vocab so we can convert word ids to words in prediction.
save_vocab(word_to_id, '/content/datalab/img2txt/vocab.yaml')

In [7]:
def transform(id_imagefiles, id_wids, image_dir, output_dir, train_filename, eval_filename, test_filename, batch_size):
    """Converts images into embeddings, joins them with captions by id, splits the results into
       train/eval/test, and saves them to tf SequenceExample files.
       
       Note that train/eval data will be SequenceExample, but test data will be text
       (a list of image file paths) because the final model expects raw images as input.
    """

    def _int64_feature(value):
        """Wrapper for inserting an int64 Feature into a SequenceExample proto."""

        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _float_feature(value):
        """Wrapper for inserting a float Feature into a SequenceExample proto."""
        
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))

    def _int64_feature_list(values):
        """Wrapper for inserting an int64 FeatureList into a SequenceExample proto."""
        
        return tf.train.FeatureList(feature=[_int64_feature(v) for v in values])


    tf.gfile.MakeDirs(output_dir)
    g = tf.Graph()
    with g.as_default():
        image_str_tensor = tf.placeholder(tf.string, shape=None)
        inception_embeddings = build_image_processing(image_str_tensor)
        vars_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=INCEPTION_EXCLUDED_VARIABLES)
        
    with tf.Session(graph=g) as sess:
        load_inception_checkpoint(sess, vars_to_restore)

        # Write to tf.example files.
        train_file = os.path.join(output_dir, train_filename)        
        eval_file = os.path.join(output_dir, eval_filename)
        writer_train = tf.python_io.TFRecordWriter(train_file)        
        writer_eval = tf.python_io.TFRecordWriter(eval_file)
        writer_test = tf.gfile.Open(os.path.join(output_dir, test_filename), 'w')
        batches = make_batches(list(six.iteritems(id_imagefiles)), batch_size)
        num_of_batches = len(id_imagefiles) / batch_size + 1
        for batch_num, b in enumerate(batches):
            start = datetime.now()
            image_bytes = []
            for img in b:
                with tf.gfile.Open(os.path.join(image_dir, img[1]), 'r') as f:
                    image_bytes.append(f.read())

            embs = sess.run(inception_embeddings, feed_dict={image_str_tensor: image_bytes}) 
            for img, emb in zip(b, embs):
                rnd_num = randint(0, 100)
                # 5% eval, 5% test, 90% training
                if rnd_num > 4:
                    writer = writer_train if rnd_num > 9 else writer_eval
                    img_id = img[0]
                    for caption_wids in id_wids[img_id]:
                        context = tf.train.Features(feature={"id": _int64_feature(img_id), "emb": _float_feature(emb.tolist())})
                        feature_lists = tf.train.FeatureLists(feature_list={"wids": _int64_feature_list(caption_wids)})
                        sequence_example = tf.train.SequenceExample(context=context, feature_lists=feature_lists)
                        writer.write(sequence_example.SerializeToString())
                else:
                    writer_test.write('%d:%s\n' % (img[0], img[1]))
            elapsed = datetime.now() - start
            print('processed batch %d of %d in %s' % (batch_num, num_of_batches, str(elapsed)))
        writer_train.close()   
        writer_eval.close()   
        writer_test.close()

In [14]:
transform(
    id_imagefiles,
    id_wids,
    image_dir='/content/datalab/img2txt/images',
    output_dir='/content/datalab/img2txt/transformed',
    train_filename='train',
    eval_filename='eval',
    test_filename='test.txt',
    batch_size=500)


INFO:tensorflow:Restoring parameters from /tmp/tmpvgOi5p/checkpoint
processed batch 0 of 19 in 0:01:13.046704
processed batch 1 of 19 in 0:01:10.849085
processed batch 2 of 19 in 0:01:11.665896
processed batch 3 of 19 in 0:01:10.556047
processed batch 4 of 19 in 0:01:09.216518
processed batch 5 of 19 in 0:01:10.277492
processed batch 6 of 19 in 0:01:09.650364
processed batch 7 of 19 in 0:01:08.967149
processed batch 8 of 19 in 0:01:08.860686
processed batch 9 of 19 in 0:01:09.274153
processed batch 10 of 19 in 0:01:10.445347
processed batch 11 of 19 in 0:01:08.872389
processed batch 12 of 19 in 0:01:09.499758
processed batch 13 of 19 in 0:01:10.078055
processed batch 14 of 19 in 0:01:09.691590
processed batch 15 of 19 in 0:01:09.614510
processed batch 16 of 19 in 0:01:10.702668
processed batch 17 of 19 in 0:01:09.337728
processed batch 18 of 19 in 0:01:09.780659

In [16]:
!ls /content/datalab/img2txt/transformed -l -h


total 309M
-rw-r--r-- 1 root root  16M Aug 24 19:36 eval
-rw-r--r-- 1 root root  18K Aug 24 19:36 test.txt
-rw-r--r-- 1 root root 293M Aug 24 19:36 train

Training

Helper functions


In [24]:
import tensorflow as tf


def parse_sequence_example(serialized):
    """Parses a tensorflow.SequenceExample into an image and caption.
    Args:
        serialized: A scalar string Tensor; a single serialized SequenceExample.
    Returns:
        id: a scalar integer Tensor.
        emb: image embeddings, a 1-D Tensor with shape [2048].
        wids: word ids, a 1-D Tensor with shape [None].
    """

    context, sequence = tf.parse_single_sequence_example(
        serialized,
        context_features={
            'id': tf.FixedLenFeature([], dtype=tf.int64),
            'emb': tf.FixedLenFeature([2048], dtype=tf.float32)
        },
        sequence_features={
            'wids': tf.FixedLenSequenceFeature([], dtype=tf.int64),
        })

    return context['id'], context['emb'], sequence['wids']


def prefetch_input_data(file_pattern, batch_size):
    """Prefetches string values from disk into an input queue.

    Args:
        file_pattern: file patterns (e.g. /tmp/train_data-?????-of-00100).
        batch_size: Model batch size used to determine queue capacity.
    Returns:
        A Queue containing prefetched string values.
    """

    data_files = tf.gfile.Glob(file_pattern)
    filename_queue = tf.train.string_input_producer(data_files, shuffle=True, capacity=16, name='filename_queue')
    capacity = 1000 + 100 * batch_size
    values_queue = tf.RandomShuffleQueue(
        capacity=capacity,
        min_after_dequeue=1000,
        dtypes=[tf.string],
        name="random_input_queue")

    enqueue_ops = []
    reader = tf.TFRecordReader()    
    _, value = reader.read(filename_queue)
    enqueue_ops.append(values_queue.enqueue([value]))
    tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner(values_queue, enqueue_ops))

    return values_queue


def build_graph(serialized_sequence_example, vocab_size, train_batch_size, embedding_size, lstm_size, mode):
    """ Build the main TensorFlow graph that will be shared by training and evaluation.
    """
    
    uniform_initializer = tf.random_uniform_initializer(minval=-0.08, maxval=0.08)
    id, img_emb, wids = parse_sequence_example(serialized_sequence_example)
    caption_length = tf.shape(wids)[0]
    input_length = tf.expand_dims(tf.subtract(caption_length, 1), 0)
    input_seq = tf.slice(wids, [0], input_length)
    target_seq = tf.slice(wids, [1], input_length)
    indicator = tf.ones(input_length, dtype=tf.int32)
    enqueue_list = [[img_emb, input_seq, target_seq, indicator]]
    img_embs, input_seqs, target_seqs, input_mask = tf.train.batch_join(
        enqueue_list,
        batch_size=train_batch_size,
        capacity=train_batch_size * 2,
        dynamic_pad=True,
        name="batch_and_pad")
    
    with tf.variable_scope("seq_embedding"), tf.device("/cpu:0"):
        embedding_map = tf.get_variable(
            name="map",
            shape=[vocab_size, embedding_size], initializer=uniform_initializer)
        seq_embeddings = tf.nn.embedding_lookup(embedding_map, input_seqs)

    with tf.variable_scope("image_embedding") as scope:
        image_embeddings = tf.contrib.layers.fully_connected(
            inputs=img_embs,
            num_outputs=embedding_size,
            activation_fn=None,
            weights_initializer=uniform_initializer,
            biases_initializer=None,
            scope=scope)        
    
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units=lstm_size, state_is_tuple=True)
    if mode == 'train':
        lstm_cell = tf.contrib.rnn.DropoutWrapper(lstm_cell, input_keep_prob=0.7, output_keep_prob=0.7)

    with tf.variable_scope("lstm", initializer=tf.random_uniform_initializer(minval=-0.08, maxval=0.08)) as lstm_scope:
        zero_state = lstm_cell.zero_state(batch_size=image_embeddings.get_shape()[0], dtype=tf.float32)
        # Use image_embeddings as initial state.
        _, initial_state = lstm_cell(image_embeddings, zero_state)
        lstm_scope.reuse_variables()
        sequence_length = tf.reduce_sum(input_mask, 1)
        lstm_outputs, _ = tf.nn.dynamic_rnn(cell=lstm_cell,
                                            inputs=seq_embeddings,
                                            sequence_length=sequence_length,
                                            initial_state=initial_state,
                                            dtype=tf.float32,
                                            scope=lstm_scope)
        
    # lstm_outputs's dim is [batch_size, max_seq_length, lstm_cell.output_size]
    # Reshape it to 2D Tensor [batch * max_seq_length, lstm_cell.output_size] for loss computation.
    lstm_outputs = tf.reshape(lstm_outputs, [-1, lstm_cell.output_size])
    with tf.variable_scope("logits") as logits_scope:
        logits = tf.contrib.layers.fully_connected(
            inputs=lstm_outputs,
            num_outputs=vocab_size,
            activation_fn=None,
            weights_initializer=uniform_initializer,
            scope=logits_scope)
    
    # Similarly, reshape targets to [batch * max_seq_length]
    targets = tf.reshape(target_seqs, [-1])
    
    weights = tf.to_float(tf.reshape(input_mask, [-1]))
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, logits=logits)
    batch_loss = tf.div(tf.reduce_sum(tf.multiply(losses, weights)), tf.reduce_sum(weights), name="batch_loss") 
    tf.summary.scalar("losses/batch_loss", batch_loss)

    global_step = tf.Variable(
        initial_value=0,
        name="global_step",
        trainable=False,
        collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES])
    
    return batch_loss, losses, weights, global_step

In [82]:
def train_graph(vocab_size, train_batch_size, training_file_pattern, embedding_size=1024, lstm_size=512):
    """Build the training graph."""
    
    train_instances_size = get_instances_size(training_file_pattern)
    g = tf.Graph()
    with g.as_default():
        input_queue = prefetch_input_data(training_file_pattern, batch_size=train_batch_size)
        serialized_sequence_example = input_queue.dequeue()
        total_loss, _, _, global_step = build_graph(
            serialized_sequence_example, vocab_size, train_batch_size, embedding_size, lstm_size, 'train')
        learning_rate = tf.constant(2.0)   # initial_learning_rate
        learning_rate_decay_factor = 0.5
        num_batches_per_epoch = (train_instances_size / train_batch_size)
        decay_steps = int(num_batches_per_epoch * 8)    # num_epochs_per_decay
    
        def _learning_rate_decay_fn(learning_rate, global_step):
            return tf.train.exponential_decay(
                learning_rate,
                global_step,
                decay_steps=decay_steps,
                decay_rate=learning_rate_decay_factor,
                staircase=True)
      
        train_op = tf.contrib.layers.optimize_loss(
            loss=total_loss,
            global_step=global_step,
            learning_rate=learning_rate,
            optimizer='SGD',
            clip_gradients=5.0,
            learning_rate_decay_fn=_learning_rate_decay_fn)   
        saver = tf.train.Saver(max_to_keep=5)
            
    return g, train_op, global_step, saver

Start training


In [26]:
# Remove previous trained model
!rm -r -f /content/datalab/img2txt/train

In [27]:
vocab = load_vocab('/content/datalab/img2txt/vocab.yaml')
vocab_size = len(vocab)

graph, train_op, global_step, saver = train_graph(
    vocab_size, 
    train_batch_size=64,
    training_file_pattern='/content/datalab/img2txt/transformed/train')

tf.contrib.slim.learning.train(
    train_op,
    '/content/datalab/img2txt/train',
    log_every_n_steps=100,
    graph=graph,
    global_step=global_step,
    number_of_steps=10000,
    saver=saver)

# Save inception checkpoint with the model.
inception_checkpoint = os.path.join('/content/datalab/img2txt/train', 'inception_checkpoint')    
with tf.gfile.Open(INCEPTION_V3_CHECKPOINT, 'r') as f_in, tf.gfile.Open(inception_checkpoint, 'w') as f_out:
    f_out.write(f_in.read())


instances size is 36771
INFO:tensorflow:Starting Session.
INFO:tensorflow:Saving checkpoint to path /content/datalab/img2txt/train/model.ckpt
INFO:tensorflow:Starting Queues.
INFO:tensorflow:global_step/sec: 0
INFO:tensorflow:Recording summary at step 0.
INFO:tensorflow:global step 100: loss = 4.0970 (0.752 sec/step)
INFO:tensorflow:global step 200: loss = 3.7678 (0.449 sec/step)
INFO:tensorflow:global step 300: loss = 3.3628 (0.400 sec/step)
INFO:tensorflow:global step 400: loss = 3.4072 (0.619 sec/step)
INFO:tensorflow:global step 500: loss = 3.2066 (0.710 sec/step)
INFO:tensorflow:global step 600: loss = 3.0650 (0.500 sec/step)
INFO:tensorflow:global step 700: loss = 2.8434 (0.416 sec/step)
INFO:tensorflow:global step 800: loss = 2.9489 (0.465 sec/step)
INFO:tensorflow:global step 900: loss = 2.8098 (0.407 sec/step)
INFO:tensorflow:global step 1000: loss = 2.8676 (0.516 sec/step)
INFO:tensorflow:global step 1100: loss = 2.7737 (0.482 sec/step)
INFO:tensorflow:Saving checkpoint to path /content/datalab/img2txt/train/model.ckpt
INFO:tensorflow:global_step/sec: 1.98337
INFO:tensorflow:Recording summary at step 1191.
INFO:tensorflow:global step 1200: loss = 2.6411 (0.443 sec/step)
INFO:tensorflow:global step 1300: loss = 2.8988 (0.564 sec/step)
INFO:tensorflow:global step 1400: loss = 2.7398 (0.390 sec/step)
INFO:tensorflow:global step 1500: loss = 2.6563 (0.395 sec/step)
INFO:tensorflow:global step 1600: loss = 2.5766 (0.574 sec/step)
INFO:tensorflow:global step 1700: loss = 2.6754 (0.894 sec/step)
INFO:tensorflow:global step 1800: loss = 2.6117 (0.499 sec/step)
INFO:tensorflow:global step 1900: loss = 2.4784 (0.430 sec/step)
INFO:tensorflow:global step 2000: loss = 2.4294 (0.394 sec/step)
INFO:tensorflow:global step 2100: loss = 2.5377 (0.566 sec/step)
INFO:tensorflow:global step 2200: loss = 2.3324 (0.464 sec/step)
INFO:tensorflow:global step 2300: loss = 2.4854 (0.368 sec/step)
INFO:tensorflow:Saving checkpoint to path /content/datalab/img2txt/train/model.ckpt
INFO:tensorflow:global_step/sec: 1.97335
INFO:tensorflow:Recording summary at step 2374.
INFO:tensorflow:global step 2400: loss = 2.2752 (0.526 sec/step)
INFO:tensorflow:global step 2500: loss = 2.4420 (0.426 sec/step)
INFO:tensorflow:global step 2600: loss = 2.3908 (0.522 sec/step)
INFO:tensorflow:global step 2700: loss = 2.3991 (0.415 sec/step)
INFO:tensorflow:global step 2800: loss = 2.3973 (0.359 sec/step)
INFO:tensorflow:global step 2900: loss = 2.3564 (0.517 sec/step)
INFO:tensorflow:global step 3000: loss = 2.3111 (0.419 sec/step)
INFO:tensorflow:global step 3100: loss = 2.2128 (0.385 sec/step)
INFO:tensorflow:global step 3200: loss = 2.3527 (0.873 sec/step)
INFO:tensorflow:global step 3300: loss = 2.2904 (0.478 sec/step)
INFO:tensorflow:global step 3400: loss = 2.3108 (0.375 sec/step)
INFO:tensorflow:global step 3500: loss = 2.2632 (0.586 sec/step)
INFO:tensorflow:Saving checkpoint to path /content/datalab/img2txt/train/model.ckpt
INFO:tensorflow:global_step/sec: 2.01167
INFO:tensorflow:Recording summary at step 3582.
INFO:tensorflow:global step 3600: loss = 2.1685 (0.589 sec/step)
INFO:tensorflow:global step 3700: loss = 2.2252 (0.481 sec/step)
INFO:tensorflow:global step 3800: loss = 2.3652 (0.548 sec/step)
INFO:tensorflow:global step 3900: loss = 2.2816 (0.655 sec/step)
INFO:tensorflow:global step 4000: loss = 2.1979 (0.404 sec/step)
INFO:tensorflow:global step 4100: loss = 2.1321 (0.584 sec/step)
INFO:tensorflow:global step 4200: loss = 2.3679 (0.405 sec/step)
INFO:tensorflow:global step 4300: loss = 2.2094 (0.889 sec/step)
INFO:tensorflow:global step 4400: loss = 2.1881 (0.510 sec/step)
INFO:tensorflow:global step 4500: loss = 2.1080 (0.427 sec/step)
INFO:tensorflow:global step 4600: loss = 2.1101 (0.415 sec/step)
INFO:tensorflow:global step 4700: loss = 2.1315 (0.443 sec/step)
INFO:tensorflow:global step 4800: loss = 2.0747 (0.361 sec/step)
INFO:tensorflow:Saving checkpoint to path /content/datalab/img2txt/train/model.ckpt
INFO:tensorflow:global_step/sec: 2.04167
INFO:tensorflow:Recording summary at step 4807.
INFO:tensorflow:global step 4900: loss = 2.0688 (0.470 sec/step)
INFO:tensorflow:global step 5000: loss = 2.1204 (0.371 sec/step)
INFO:tensorflow:global step 5100: loss = 2.1013 (0.559 sec/step)
INFO:tensorflow:global step 5200: loss = 2.0700 (0.517 sec/step)
INFO:tensorflow:global step 5300: loss = 2.0562 (0.410 sec/step)
INFO:tensorflow:global step 5400: loss = 2.1847 (0.552 sec/step)
INFO:tensorflow:global step 5500: loss = 2.0911 (0.530 sec/step)
INFO:tensorflow:global step 5600: loss = 2.0330 (0.480 sec/step)
INFO:tensorflow:global step 5700: loss = 2.0791 (0.561 sec/step)
INFO:tensorflow:global step 5800: loss = 2.0343 (0.582 sec/step)
INFO:tensorflow:global step 5900: loss = 2.0457 (0.433 sec/step)
INFO:tensorflow:global step 6000: loss = 1.9746 (0.650 sec/step)
INFO:tensorflow:Saving checkpoint to path /content/datalab/img2txt/train/model.ckpt
INFO:tensorflow:global_step/sec: 2.05166
INFO:tensorflow:Recording summary at step 6037.
INFO:tensorflow:global step 6100: loss = 1.9113 (0.597 sec/step)
INFO:tensorflow:global step 6200: loss = 2.3442 (0.803 sec/step)
INFO:tensorflow:global step 6300: loss = 2.0370 (0.514 sec/step)
INFO:tensorflow:global step 6400: loss = 2.0980 (0.437 sec/step)
INFO:tensorflow:global step 6500: loss = 2.0921 (0.539 sec/step)
INFO:tensorflow:global step 6600: loss = 2.0364 (0.460 sec/step)
INFO:tensorflow:global step 6700: loss = 1.8723 (0.654 sec/step)
INFO:tensorflow:global step 6800: loss = 2.1641 (0.506 sec/step)
INFO:tensorflow:global step 6900: loss = 1.7602 (0.432 sec/step)
INFO:tensorflow:global step 7000: loss = 1.9271 (0.457 sec/step)
INFO:tensorflow:global step 7100: loss = 1.9710 (0.350 sec/step)
INFO:tensorflow:global step 7200: loss = 1.9334 (0.611 sec/step)
INFO:tensorflow:Saving checkpoint to path /content/datalab/img2txt/train/model.ckpt
INFO:tensorflow:global_step/sec: 2.04335
INFO:tensorflow:Recording summary at step 7263.
INFO:tensorflow:global step 7300: loss = 2.0254 (0.549 sec/step)
INFO:tensorflow:global step 7400: loss = 2.0725 (0.515 sec/step)
INFO:tensorflow:global step 7500: loss = 2.2280 (0.452 sec/step)
INFO:tensorflow:global step 7600: loss = 1.8325 (0.470 sec/step)
INFO:tensorflow:global step 7700: loss = 2.1343 (0.492 sec/step)
INFO:tensorflow:global step 7800: loss = 1.9464 (0.473 sec/step)
INFO:tensorflow:global step 7900: loss = 2.0577 (0.495 sec/step)
INFO:tensorflow:global step 8000: loss = 1.8358 (0.413 sec/step)
INFO:tensorflow:global step 8100: loss = 2.0890 (0.456 sec/step)
INFO:tensorflow:global step 8200: loss = 1.8687 (0.480 sec/step)
INFO:tensorflow:global step 8300: loss = 1.8032 (0.377 sec/step)
INFO:tensorflow:global step 8400: loss = 1.8096 (0.487 sec/step)
INFO:tensorflow:global step 8500: loss = 2.0428 (0.563 sec/step)
INFO:tensorflow:Saving checkpoint to path /content/datalab/img2txt/train/model.ckpt
INFO:tensorflow:global_step/sec: 2.085
INFO:tensorflow:Recording summary at step 8515.
INFO:tensorflow:global step 8600: loss = 1.8879 (0.477 sec/step)
INFO:tensorflow:global step 8700: loss = 1.9199 (0.456 sec/step)
INFO:tensorflow:global step 8800: loss = 2.0376 (0.519 sec/step)
INFO:tensorflow:global step 8900: loss = 1.8503 (0.368 sec/step)
INFO:tensorflow:global step 9000: loss = 1.6652 (0.960 sec/step)
INFO:tensorflow:global step 9100: loss = 1.7716 (0.516 sec/step)
INFO:tensorflow:global step 9200: loss = 1.9044 (0.524 sec/step)
INFO:tensorflow:global step 9300: loss = 1.8633 (0.524 sec/step)
INFO:tensorflow:global step 9400: loss = 1.8051 (0.531 sec/step)
INFO:tensorflow:global step 9500: loss = 1.9589 (0.535 sec/step)
INFO:tensorflow:global step 9600: loss = 1.7757 (0.476 sec/step)
INFO:tensorflow:global step 9700: loss = 1.8710 (0.416 sec/step)
INFO:tensorflow:Saving checkpoint to path /content/datalab/img2txt/train/model.ckpt
INFO:tensorflow:global_step/sec: 2.07832
INFO:tensorflow:Recording summary at step 9761.
INFO:tensorflow:global step 9800: loss = 1.9027 (0.552 sec/step)
INFO:tensorflow:global step 9900: loss = 1.8002 (0.572 sec/step)
INFO:tensorflow:global step 10000: loss = 1.9014 (0.540 sec/step)
INFO:tensorflow:Stopping Training.
INFO:tensorflow:Finished training! Saving model to disk.

Check training loss


In [28]:
from google.datalab.ml import Summary

summary = Summary('/content/datalab/img2txt/train')
summary.list_events()


Out[28]:
{u'OptimizeLoss/learning_rate': {'/content/datalab/img2txt/train'},
 u'OptimizeLoss/loss': {'/content/datalab/img2txt/train'},
 u'batch_and_pad/fraction_of_128_full': {'/content/datalab/img2txt/train'},
 u'filename_queue/fraction_of_16_full': {'/content/datalab/img2txt/train'},
 u'global_step/sec': {'/content/datalab/img2txt/train'},
 u'losses/batch_loss': {'/content/datalab/img2txt/train'}}

In [29]:
summary.plot('losses/batch_loss')


Evaluate


In [80]:
import math
import numpy as np


def eval_graph(vocab_size, eval_batch_size, eval_file_pattern, embedding_size=1024, lstm_size=512):
    """Build evaluation graph."""

    g = tf.Graph()
    with g.as_default():
        input_queue = prefetch_input_data(eval_file_pattern, batch_size=eval_batch_size)
        serialized_sequence_example = input_queue.dequeue()
        _, losses, weights, global_step = build_graph(serialized_sequence_example, vocab_size, eval_batch_size, embedding_size, lstm_size, 'eval')
        saver = tf.train.Saver()
    return g, losses, weights, global_step, saver


def eval_model(vocab_size, train_dir, eval_file_pattern, eval_batch_size=64):
    """Evaluate a trained model with evaluation data."""
    
    eval_instances_size = get_instances_size(eval_file_pattern)
    graph, losses, weights, global_step, saver = eval_graph(vocab_size, eval_batch_size=eval_batch_size, eval_file_pattern=eval_file_pattern)
    checkpoint = tf.train.latest_checkpoint(train_dir)
    with tf.Session(graph=graph) as sess:
        saver.restore(sess, checkpoint)
        global_step_val = tf.train.global_step(sess, global_step.name)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        num_eval_batches = int(math.ceil(eval_instances_size / eval_batch_size))
        
        sum_losses = 0.
        sum_weights = 0.
        for i in xrange(num_eval_batches):
            losses_val, weights_val = sess.run([losses, weights])
            sum_losses += np.sum(losses_val * weights_val)
            sum_weights += np.sum(weights_val)
            if i % 10 == 0:
                tf.logging.info("Computed losses for %d of %d batches.", i + 1, num_eval_batches)

        perplexity = math.exp(sum_losses / sum_weights)
        tf.logging.info("Perplexity = %f", perplexity)
        tf.logging.info("Finished processing evaluation at global step %d.", global_step_val)
        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)

In [81]:
eval_model(vocab_size, '/content/datalab/img2txt/train', '/content/datalab/img2txt/transformed/eval')


instances size is 1924
INFO:tensorflow:Restoring parameters from /content/datalab/img2txt/train/model.ckpt-10000
INFO:tensorflow:Computed losses for 1 of 30 batches.
INFO:tensorflow:Computed losses for 11 of 30 batches.
INFO:tensorflow:Computed losses for 21 of 30 batches.
INFO:tensorflow:Perplexity = 8.742851
INFO:tensorflow:Finished processing evaluation at global step 10000.

Predict

The prediction graph is mostly similar to the train/eval graph, and they share all variables. The differences between them are:

  1. The prediction graph contains the Inception graph, which converts an image to embeddings. Therefore the prediction graph takes a raw image as input.
  2. The LSTM is unrolled a single step at a time (num_step is 1): the model outputs one word id per step.

In [46]:
import tensorflow as tf


def predict_graph(vocab_size, embedding_size=1024, lstm_size=512):
    g = tf.Graph()
    with g.as_default():
        image_feed = tf.placeholder(dtype=tf.string, shape=[], name="image_feed")
        input_feed = tf.placeholder(dtype=tf.int64, shape=[None], name="input_feed")        

        images = tf.expand_dims(image_feed, 0)
        input_seqs = tf.expand_dims(input_feed, 1)
        
        inception_embeddings = build_image_processing(images)
        inception_vars = tf.contrib.slim.get_variables_to_restore(exclude=INCEPTION_EXCLUDED_VARIABLES)    
        
        with tf.variable_scope("seq_embedding"):
            embedding_map = tf.get_variable(
                name="map",
                shape=[vocab_size, embedding_size])
            seq_embeddings = tf.nn.embedding_lookup(embedding_map, input_seqs)
        
        with tf.variable_scope("image_embedding") as scope:
            image_embeddings = tf.contrib.layers.fully_connected(
                inputs=inception_embeddings,
                num_outputs=embedding_size,
                activation_fn=None,
                biases_initializer=None,
                scope=scope)

        lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units=lstm_size, state_is_tuple=True)            
        with tf.variable_scope("lstm") as lstm_scope:
            zero_state = lstm_cell.zero_state(batch_size=image_embeddings.get_shape()[0], dtype=tf.float32)
            _, initial_state = lstm_cell(image_embeddings, zero_state)
            initial_state = tf.concat(axis=1, values=initial_state)
            lstm_scope.reuse_variables()        
            tf.concat(axis=1, values=initial_state, name="initial_state")
            state_feed = tf.placeholder(dtype=tf.float32, shape=[None, sum(lstm_cell.state_size)], name="state_feed")
            state_tuple = tf.split(value=state_feed, num_or_size_splits=2, axis=1)
            lstm_outputs, state_tuple = lstm_cell(inputs=tf.squeeze(seq_embeddings, axis=[1]), state=state_tuple)
            lstm_state = tf.concat(axis=1, values=state_tuple, name="state")            

        with tf.variable_scope("logits") as logits_scope:
            logits = tf.contrib.layers.fully_connected(
                inputs=lstm_outputs,
                num_outputs=vocab_size,
                activation_fn=None,
                scope=logits_scope)
        
        softmax = tf.nn.softmax(logits, name="softmax")
        trainable_vars = tf.contrib.slim.get_variables_to_restore(exclude=['InceptionV3/*'])
        return g, image_feed, input_feed, state_feed, initial_state, lstm_state, softmax, inception_vars, trainable_vars

We also need a "beam search". On one extreme, with max_caption_length of 20 we could enumerate all vocab_size^20 possible captions and pick the one with the greatest probability, which is intractable. On the other extreme, we could greedily pick only the top word at each step, producing a single caption that may not be the most probable one. Beam search keeps track of the top n partial captions at each step, and the final result is also n predictions.
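To make the idea concrete, here is a tiny, self-contained beam search over a hand-made toy "language model". The vocabulary and transition probabilities below are invented purely for illustration and have nothing to do with the trained model; it is only a sketch of the algorithm, not the notebook's implementation (which follows below).

import heapq
import math

# Toy next-word probabilities, invented for illustration only.
probs = {
    '<s>': {'a': 0.9, 'cat': 0.05, 'dog': 0.05},
    'a':   {'cat': 0.5, 'dog': 0.4, '</s>': 0.1},
    'cat': {'</s>': 0.9, 'a': 0.1},
    'dog': {'</s>': 0.9, 'a': 0.1},
}

def toy_beam_search(beam_size=2, max_len=4):
    # Each beam entry is (log-probability, sentence); start from '<s>' alone.
    beams = [(0.0, ['<s>'])]
    complete = []
    for _ in range(max_len):
        candidates = []
        for logprob, sent in beams:
            for word, p in probs.get(sent[-1], {}).items():
                entry = (logprob + math.log(p), sent + [word])
                # Captions that reached '</s>' are complete; others stay partial.
                (complete if word == '</s>' else candidates).append(entry)
        # Keep only the beam_size best partial captions for the next step.
        beams = heapq.nlargest(beam_size, candidates)
        if not beams:
            break
    return heapq.nlargest(beam_size, complete or beams)

for logprob, sent in toy_beam_search():
    print('%.3f %s' % (logprob, ' '.join(sent)))

The real implementation below does the same thing, except the next-word probabilities come from the LSTM's softmax and the LSTM state is carried along with each partial caption.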


In [70]:
import heapq
import math
import numpy as np


class Caption(object):
    """Represents a complete or partial caption."""

    def __init__(self, sentence, state, logprob, score, metadata=None):
        """Initializes the Caption.
        Args:
            sentence: List of word ids in the caption.
            state: Model state after generating the previous word.
            logprob: Log-probability of the caption.
            score: Score of the caption.
        """

        self.sentence = sentence
        self.state = state
        self.logprob = logprob
        self.score = score

    def __cmp__(self, other):
        """Compares Captions by score."""
        assert isinstance(other, Caption)
        if self.score == other.score:
            return 0
        elif self.score < other.score:
            return -1
        else:
            return 1
  
    # For Python 3 compatibility (__cmp__ is deprecated).
    def __lt__(self, other):
        assert isinstance(other, Caption)
        return self.score < other.score
  
    # Also for Python 3 compatibility.
    def __eq__(self, other):
        assert isinstance(other, Caption)
        return self.score == other.score


class TopN(object):
    """Maintains the top n elements of an incrementally provided set."""

    def __init__(self, n):
        self._n = n
        self._data = []

    def size(self):
        assert self._data is not None
        return len(self._data)

    def push(self, x):
        """Pushes a new element."""
        assert self._data is not None
        if len(self._data) < self._n:
            heapq.heappush(self._data, x)
        else:
            heapq.heappushpop(self._data, x)

    def extract(self, sort=False):
        """Extracts all elements from the TopN. This is a destructive operation.
        The only method that can be called immediately after extract() is reset().
        Args:
          sort: Whether to return the elements in descending sorted order.
        Returns:
          A list of data; the top n elements provided to the set.
        """
        assert self._data is not None
        data = self._data
        self._data = None
        if sort:
            data.sort(reverse=True)
        return data

    def reset(self):
        """Returns the TopN to an empty state."""
        self._data = []

In [90]:
from PIL import Image
from IPython.display import display
import numpy as np
import math
import os


class ShowAndTellModel(object):
    
    def __init__(self, train_dir, vocab_file, max_caption_length=20, beam_size=5):
        self._vocab = load_vocab(vocab_file)
        self._train_dir = train_dir
        self._max_caption_length = max_caption_length
        self._beam_size = beam_size
        
    def __enter__(self):
        self._graph, self._image_feed, self._input_feed, self._state_feed, self._initial_state, \
            self._lstm_state, self._softmax, inception_vars, trainable_vars = predict_graph(len(self._vocab))
        
        self._sess = tf.Session(graph=self._graph)
        
        inception_checkpoint = os.path.join(self._train_dir, 'inception_checkpoint')
        load_inception_checkpoint(self._sess, inception_vars, inception_checkpoint)
        saver = tf.train.Saver(trainable_vars)
        checkpoint_path = tf.train.latest_checkpoint(self._train_dir)
        saver.restore(self._sess, checkpoint_path)
        return self

    def __exit__(self, *args):
        self._sess.close()
        
    def _process_results(self, captions):
        id_to_word = {v: k for k, v in six.iteritems(self._vocab)}
        for caption in captions:
            words = [id_to_word[x] for x in caption.sentence]
            words = filter(lambda x: x not in ['<s>', '</s>'], words)
            yield ' '.join(words)
        
    def _predict(self, img_file):
        
        with tf.gfile.GFile(img_file, 'r') as f:
            image_bytes = f.read()

        init_state = self._sess.run(self._initial_state, feed_dict={self._image_feed: image_bytes})
        initial_beam = Caption(sentence=[self._vocab['<s>']], state=init_state[0], logprob=0.0, score=0.0)
        partial_captions = TopN(self._beam_size)
        partial_captions.push(initial_beam)
        complete_captions = TopN(self._beam_size)

        # Run beam search.
        for _ in range(self._max_caption_length - 1):
            partial_captions_list = partial_captions.extract()
            partial_captions.reset()
            input_feed_val = np.array([c.sentence[-1] for c in partial_captions_list])
            state_feed_val = np.array([c.state for c in partial_captions_list])
        
            softmax_val, new_states = self._sess.run([self._softmax, self._lstm_state],
                                               feed_dict={self._input_feed: input_feed_val, self._state_feed: state_feed_val})

            for i, partial_caption in enumerate(partial_captions_list):
                word_probabilities = softmax_val[i]
                state = new_states[i]
                # For this partial caption, get the beam_size most probable next words.
                words_and_probs = list(enumerate(word_probabilities))
                words_and_probs.sort(key=lambda x: -x[1])
                words_and_probs = words_and_probs[0:self._beam_size]
                # Each next word gives a new partial caption.
                for w, p in words_and_probs:
                    if p < 1e-12:
                        continue  # Avoid log(0).
                    sentence = partial_caption.sentence + [w]
                    logprob = partial_caption.logprob + math.log(p)
                    score = logprob
                    if w == self._vocab['</s>']:
                        beam = Caption(sentence, state, logprob, score, None)
                        complete_captions.push(beam)
                    else:
                        beam = Caption(sentence, state, logprob, score, None)
                        partial_captions.push(beam)
            if partial_captions.size() == 0:
                # We have run out of partial candidates; happens when beam_size = 1.
                break

        # If we have no complete captions then fall back to the partial captions.
        # But never output a mixture of complete and partial captions because a
        # partial caption could have a higher score than all the complete captions.
        if not complete_captions.size():
            complete_captions = partial_captions

        return complete_captions.extract(sort=True)


    def show_and_tell(self, image_file): 
        with tf.gfile.GFile(image_file) as f:
            img = Image.open(f)
            img.thumbnail((299, 299), Image.ANTIALIAS)
            display(img)
        c = self._predict(image_file)
        for r in self._process_results(c):
            print(r)

Pick the first 10 instances from the test file.


In [75]:
!head /content/datalab/img2txt/transformed/test.txt


165854:COCO_train2014_000000165854.jpg
524382:COCO_val2014_000000524382.jpg
524476:COCO_train2014_000000524476.jpg
491728:COCO_train2014_000000491728.jpg
33111:COCO_train2014_000000033111.jpg
344127:COCO_train2014_000000344127.jpg
169365:COCO_train2014_000000169365.jpg
98732:COCO_train2014_000000098732.jpg
557508:COCO_train2014_000000557508.jpg
492030:COCO_train2014_000000492030.jpg

In [91]:
with ShowAndTellModel(train_dir='/content/datalab/img2txt/train',
                      vocab_file='/content/datalab/img2txt/vocab.yaml') as m:
    m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000165854.jpg')
    m.show_and_tell('/content/datalab/img2txt/images/COCO_val2014_000000524382.jpg')
    m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000524476.jpg')
    m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000491728.jpg')
    m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000033111.jpg')  
    m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000344127.jpg')  
    m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000169365.jpg')  
    m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000098732.jpg')  
    m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000557508.jpg')  
    m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000492030.jpg')


INFO:tensorflow:Restoring parameters from /content/datalab/img2txt/train/inception_checkpoint
INFO:tensorflow:Restoring parameters from /content/datalab/img2txt/train/model.ckpt-10000
a black and white cat looking at a mirror
a black and white cat is looking at the camera
a black and white cat looking at something
a black and white cat looking at its reflection in a mirror
a black and white cat is looking at a mirror
a group of dogs standing next to each other
two dogs in the grass looking at each other
a group of dogs standing in front of a building
a group of dogs standing next to a fence
a group of dogs standing next to each other in a garden
a woman holding a cat in her arms
a woman holding a white cat in her arms
a woman is holding a cat in her arms
a woman holding a cat in her arms in her arms
a woman holding a cat in her arms in a room
a cat sitting on the hood of a truck
a cat is sitting on the hood of a truck
a cat sitting on the roof of a truck
a cat is sitting on the roof of a truck
a cat sitting on the side of a truck
a black and white dog laying on top of a couch
a black and white dog laying on top of a wooden floor
a black and white dog laying on top of a bed
a black and white dog laying next to a person
a black and white dog laying on the floor
a man playing with a dog on a leash
a man is playing with a dog on a leash
a man is playing with a large dog
a man is playing with a dog in a park
a man playing with a dog with a frisbee
a cat is sitting in a bathroom sink
there is a cat that is sitting in a sink
there is a cat that is sitting in the sink
there is a cat that is laying in the sink
there is a cat that is laying in a sink
a hot dog and french fries are on a table
a hot dog and french fries are on a plate
a hot dog and french fries are on a tray
a hot dog sitting on a table next to a drink
a plate with a hot dog and a cup
a dog laying on the floor in front of a store
a dog laying on the floor in front of a mirror
a group of dogs sitting on a wooden floor
a dog laying on the floor in front of a building
a group of dogs sitting on the floor in front of a window
a group of people walking a dog on a beach
a dog on a surfboard in the ocean
a dog on a surfboard in the water
a group of people walking a dog on the beach
a man and a dog on a surfboard in the water

For fun, let's give it a try on pictures of my cats!


In [92]:
with ShowAndTellModel(train_dir='/content/datalab/img2txt/train',
                      vocab_file='/content/datalab/img2txt/vocab.yaml') as m:
    m.show_and_tell('gs://bradley-sample-notebook-data/chopin_vivaldi.jpg')
    m.show_and_tell('gs://bradley-sample-notebook-data/vivaldi_chopin_tail.jpg')
    m.show_and_tell('gs://bradley-sample-notebook-data/vivaldi.jpg')


INFO:tensorflow:Restoring parameters from /content/datalab/img2txt/train/inception_checkpoint
INFO:tensorflow:Restoring parameters from /content/datalab/img2txt/train/model.ckpt-10000
a cat laying on the floor in front of a mirror
a cat laying on the floor next to a mirror
a cat is laying on the floor next to a mirror
a cat laying on the floor next to a cat
a cat laying on the floor next to a pair of shoes
a couple of cats laying on top of a couch
a couple of cats laying on top of a bed
a couple of cats sitting on top of a couch
a couple of cats are laying on a couch
a couple of cats laying on top of a blanket
a cat that is laying down on the ground
a cat that is laying down on a carpet
a cat laying on top of a wooden floor
a cat laying on the ground next to a pair of shoes
a cat laying on the floor next to a pair of shoes

Clean up

All of our files are created under /content/datalab/img2txt, so just remove that directory to clean up.


In [ ]:
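# Remove everything produced by this notebook (downloaded data, transformed
# files, and the trained model), as described above.
!rm -r -f /content/datalab/img2txt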